import pickle

import torch.nn.functional as F
import torch
import numpy as np
from tqdm import tqdm

from data_utils import PlainLoader as MyLoader
from data_utils import PlainListLoader
from utils import evaluate

def distill_train(device, features, teacher_emb, g, model, train_conf, verbose=True): # pylint: disable=redefined-outer-name
    """Train ``model`` with a weighted mix of distillation and cross-entropy loss.

    Each epoch runs up to two optimisation passes:

    * a distillation pass over the train, val and test nodes, minimising
      ``KLDivLoss(reduction="batchmean", log_target=True)`` between the
      model's log-softmax output and ``teacher_emb``, scaled by
      ``train_conf['lam_distill']`` (skipped when ``lam_distill <= 0``);
      labels are not used in this pass;
    * a supervised pass over the train nodes, minimising cross-entropy
      against ``g.ndata['label']``, scaled by ``1 - lam_distill``
      (skipped when ``lam_distill >= 1``).

    After every epoch the model is scored on the validation nodes with
    ``evaluate``. The pickled state dict of the best-scoring epoch is kept,
    and training stops early after ``train_conf['patience']`` consecutive
    epochs without improvement.

    Args:
        device: torch device that features, labels, indices and the teacher
            outputs are moved to.
        features: node features indexed by node id (presumably
            ``(num_nodes, feat_dim)`` -- TODO confirm against callers).
        teacher_emb: teacher outputs indexed by node id. Because of
            ``log_target=True`` these are assumed to be log-probabilities with
            the same width as the model output -- TODO confirm. May be
            ``None``; a placeholder is then substituted only so the
            distillation loader can be built.
        g: graph exposing ``train_idx``, ``val_idx``, ``test_idx``,
            ``ndata['label']`` and ``num_nodes()``.
        model: module mapping a feature batch to class logits.
        train_conf: dict read for ``batch_size``, ``lr``, ``weight_decay``,
            ``epoch``, ``lam_distill`` and ``patience``.
        verbose: if true, also evaluate on the test nodes and print the
            validation/test scores each epoch.

    Returns:
        ``pickle.dumps(model.state_dict())`` taken at the epoch with the
        highest validation score, or ``None`` if no epoch was run.
    """
    # create sampler & dataloader
    train_idx = g.train_idx.to(device)
    val_idx = g.val_idx.to(device)
    test_idx = g.test_idx.to(device)
    features = features.to(device)
    # as_tensor accepts array-likes like torch.tensor, but does not copy (and
    # warn) when the labels are already a tensor.
    labels = torch.as_tensor(g.ndata['label']).to(device)
    if teacher_emb is None:
        # Placeholder so PlainListLoader has a third column to index. It is not
        # a meaningful distillation target; presumably lam_distill is 0 here.
        teacher_emb = torch.from_numpy(np.arange(g.num_nodes()))
    teacher_emb = teacher_emb.to(device)

    batch_size = train_conf["batch_size"]
    lam_distill = train_conf['lam_distill']
    # The distillation pass covers train, val and test nodes alike.
    distill_idx = torch.cat((train_idx, val_idx, test_idx))
    train_dataloader_distill = PlainListLoader([features, labels, teacher_emb], batch_size, distill_idx)
    train_dataloader = MyLoader(features, labels, batch_size, train_idx)
    val_dataloader = MyLoader(features, labels, batch_size, val_idx)
    test_dataloader = MyLoader(features, labels, batch_size, test_idx)

    opt = torch.optim.Adam(model.parameters(), lr=train_conf["lr"],
                           weight_decay=train_conf['weight_decay'])

    dis_loss = torch.nn.KLDivLoss(reduction="batchmean", log_target=True)
    # Start below any possible score so the first evaluated epoch is always
    # recorded; otherwise a model scoring 0 would leave best_state as None.
    best_state, best_val, best_epoch = None, float("-inf"), 0
    patience_cnt = 0
    for epoch in tqdm(range(train_conf["epoch"])):
        model.train()
        # distillation pass (labels unused)
        if lam_distill > 0:
            for x, _, z in train_dataloader_distill:
                y_hat = model(x)
                loss = dis_loss(y_hat.log_softmax(dim=1), z) * lam_distill
                opt.zero_grad()
                loss.backward()
                opt.step()
        # supervised pass
        if 1 - lam_distill > 0:
            for x, y in train_dataloader:
                y_hat = model(x)
                loss = F.cross_entropy(y_hat, y) * (1 - lam_distill)
                opt.zero_grad()
                loss.backward()
                opt.step()

        acc = evaluate(model, val_dataloader).item() # pylint: disable=redefined-outer-name
        if acc > best_val:
            best_val = acc
            # pickle snapshots the tensor values; a bare state_dict() would
            # keep referencing the live, still-training parameters.
            best_state = pickle.dumps(model.state_dict())
            best_epoch = epoch
            patience_cnt = 0
        else:
            patience_cnt += 1
        if verbose:
            test_acc = evaluate(model, test_dataloader)
            print(f"    epoch: {epoch}, val acc: {acc}, test acc: {test_acc.item()}")
        if patience_cnt >= train_conf["patience"]:
            # Keep best_state/best_epoch from the best validation epoch: the
            # current weights are `patience` epochs past it.
            print("Triggered early stopping.")
            break

    print(f"Best epoch {best_epoch:05d} | f1_micro {best_val:.4f}")

    return best_state